"""
Implementation of a dual BiFPN backbone: one FPN branch is used for predicting objects and the other for RAFs.
"""
__author__ = "Hengyue Liu"
__copyright__ = "Copyright (c) 2021 Futurewei Inc."
__credits__ = ["Detectron2"]
__license__ = "MIT License"
__version__ = "0.1"
__maintainer__ = "Hengyue Liu"
__email__ = "onehothenry@gmail.com"

import math

import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from torch.nn.parameter import Parameter

import fvcore.nn.weight_init as weight_init
from detectron2.modeling.backbone import Backbone, build_resnet_backbone
# from detectron2.modeling.backbone.fpn import LastLevelP6P7
from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
from detectron2.layers import Conv2d, FrozenBatchNorm2d, ShapeSpec, get_norm

from fcsgg.layers import Conv2dWithPadding, SeparableConv2d, MaxPool2d, Swish, MemoryEfficientSwish
# from .efficientnet import build_efficientnet_backbone
from .bifpn import BiFPN, _assert_strides_are_log2_contiguous, BeforeBiFPNLayer, BiFPNLayer

class DualBiFPN(Backbone):
    """
    Two independent Bi-Directional Feature Pyramid Networks (BiFPNs) built on
    top of one shared bottom-up backbone.

    Both branches consume the same bottom-up feature maps. Outputs of branch 0
    are named "p<stage>_0" and outputs of branch 1 are named "p<stage>_1",
    where a feature named "p<stage>_*" has stride 2 ** stage.
    """

    def __init__(self,
                 bottom_up,
                 in_features,
                 out_channels,
                 fpn_repeat,
                 norm="SyncBN",
                 top_block=None,
                 fuse_type="sum"):
        """
        Args:
            bottom_up (Backbone): module representing the bottom up subnetwork.
                Must be a subclass of :class:`Backbone`. The multi-scale feature
                maps generated by the bottom up network, and listed in `in_features`,
                are used to generate FPN levels.
            in_features (list[str]): names of the input feature maps coming
                from the backbone to which FPN is attached. For example, if the
                backbone produces ["res2", "res3", "res4"], any *contiguous* sublist
                of these may be used; order must be from high to low resolution.
            out_channels (int): number of channels in the output feature maps.
            fpn_repeat (int): number of stacked BiFPNLayer modules in each branch.
            norm (str): the normalization to use.
            top_block (nn.Module or None): if provided, it is passed to both
                BeforeBiFPNLayer modules and adds extra, lower-resolution levels
                to each branch. It must have an attribute "num_levels", the
                number of extra FPN levels it adds.
            fuse_type (str): accepted for interface compatibility with the other
                FPN builders; it is currently not used by this module.
        """
        super(DualBiFPN, self).__init__()
        assert isinstance(bottom_up, Backbone)

        # Feature map strides and channels from the bottom up network (e.g. ResNet).
        # These are fresh lists, so appending the top-block levels below is safe.
        in_strides = [bottom_up._out_feature_strides[f] for f in in_features]
        in_channels = [bottom_up._out_feature_channels[f] for f in in_features]

        self.in_features = in_features
        self.bottom_up = bottom_up

        # Extra levels produced by the optional top block, each halving the resolution.
        last_stage = int(math.log2(in_strides[-1]))
        extra_levels = top_block.num_levels if top_block else 0
        for s in range(last_stage, last_stage + extra_levels):
            in_strides.append(2 ** (s + 1))
            in_channels.append(out_channels)
        _assert_strides_are_log2_contiguous(in_strides)

        # Output names: every level of branch 0 (including top-block levels)
        # first, then every level of branch 1. This order must match forward(),
        # which concatenates the outputs of bifpn_0 and bifpn_1 in that order.
        self._out_feature_strides = {}
        for branch in (0, 1):
            for stride in in_strides:
                self._out_feature_strides[f"p{int(math.log2(stride))}_{branch}"] = stride

        self.before_bifpn_0 = BeforeBiFPNLayer(out_channels, in_channels, top_block=top_block, norm=norm)
        self.before_bifpn_1 = BeforeBiFPNLayer(out_channels, in_channels, kernel_size=3, top_block=top_block, norm=norm)
        self.bifpn_0 = self._build_bifpn_stack(out_channels, self.before_bifpn_0.num_stages, fpn_repeat, norm)
        self.bifpn_1 = self._build_bifpn_stack(out_channels, self.before_bifpn_1.num_stages, fpn_repeat, norm)

        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = {k: out_channels for k in self._out_features}
        # Strides are ascending within each branch, so the last output has the largest stride.
        self._size_divisibility = self._out_feature_strides[self._out_features[-1]]

    @staticmethod
    def _build_bifpn_stack(out_channels, num_stages, fpn_repeat, norm):
        """Stack `fpn_repeat` BiFPNLayers; only the first is built with lateral=True."""
        return nn.Sequential(*[
            BiFPNLayer(out_channels, num_stages, norm=norm, lateral=(i == 0))
            for i in range(fpn_repeat)
        ])

    @property
    def size_divisibility(self):
        return self._size_divisibility

    def forward(self, x):
        """
        Args:
            x: input passed unchanged to ``self.bottom_up``; presumably a batched
                image tensor (N, C, H, W) — TODO confirm against the caller.

        Returns:
            dict[str->Tensor]: maps each name in ``self._out_features`` (all
                "p<stage>_0" outputs of branch 0, then all "p<stage>_1" outputs of
                branch 1) to the corresponding BiFPN output feature map.
        """
        bottom_up_features = self.bottom_up(x)
        # Pick the configured backbone features, in `in_features` order (high to low resolution).
        features = [bottom_up_features[f] for f in self.in_features]

        lateral_features_0, skip_features_0 = self.before_bifpn_0(features)
        features_0, _ = self.bifpn_0((lateral_features_0, skip_features_0))
        lateral_features_1, skip_features_1 = self.before_bifpn_1(features)
        features_1, _ = self.bifpn_1((lateral_features_1, skip_features_1))
        # Branch 0 first, then branch 1: matches the name order built in __init__.
        features = features_0 + features_1

        assert len(self._out_features) == len(features)
        return dict(zip(self._out_features, features))

@BACKBONE_REGISTRY.register()
def build_resnet_dual_bifpn_backbone(cfg, input_shape: ShapeSpec):
    """
    Build a ResNet bottom-up backbone topped with a :class:`DualBiFPN`.

    Args:
        cfg: a detectron2 CfgNode. Reads MODEL.FPN.IN_FEATURES, OUT_CHANNELS,
            REPEAT, NORM and FUSE_TYPE; the ResNet part is configured by
            ``build_resnet_backbone``.
        input_shape (ShapeSpec): forwarded to ``build_resnet_backbone``.

    Returns:
        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
        No top block is attached (``top_block=None``).

    Raises:
        ValueError: if MODEL.FPN.IN_FEATURES does not list exactly 4 features.
    """
    fpn_cfg = cfg.MODEL.FPN
    in_features = fpn_cfg.IN_FEATURES
    # Explicit check instead of `assert` so the validation still runs under `python -O`;
    # done before building the ResNet so a bad config fails fast.
    if len(in_features) != 4:
        raise ValueError(
            "build_resnet_dual_bifpn_backbone expects exactly 4 MODEL.FPN.IN_FEATURES, "
            f"got {len(in_features)}: {list(in_features)}"
        )

    bottom_up = build_resnet_backbone(cfg, input_shape)
    return DualBiFPN(bottom_up=bottom_up,
                     in_features=in_features,
                     out_channels=fpn_cfg.OUT_CHANNELS,
                     fpn_repeat=fpn_cfg.REPEAT,
                     norm=fpn_cfg.NORM,
                     top_block=None,
                     fuse_type=fpn_cfg.FUSE_TYPE)